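# Assumptions:
# - The point clouds must be present in the directory 'modelnet40_3cols',
#   each inside the folder named after its class.
#   For instance, 'airplane_0001.txt' must be present in 'modelnet40_3cols/airplane'.
# - Each point-cloud file is comma-separated text with one point per line.
# - The sample lists 'modelnet40_train.txt' and 'modelnet40_test.txt' are read
#   from the current working directory.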

from sklearn.pipeline import make_pipeline, make_union
from gtda.diagrams import PersistenceEntropy, Amplitude
from gtda.images import HeightFiltration, RadialFiltration, DilationFiltration, ErosionFiltration, SignedDistanceFiltration, DensityFiltration
from gtda.homology import CubicalPersistence
import numpy as np
import open3d as o3d
from pathlib import Path
from tqdm import tqdm

def get_label(file_path):
    """Return the class label encoded in the file name of ``file_path``.

    The label is everything that precedes the last underscore of the final
    path component, e.g. ``night_stand_0001.txt`` -> ``night_stand``.  A name
    that contains no underscore yields an empty string.
    """
    file_name = Path(file_path).name
    # rpartition returns an empty head when no underscore is present.
    label, _, _ = file_name.rpartition('_')
    return label

# All 26 non-zero direction vectors with components in {-1, 0, 1}.
# The first component runs over -1, 0, 1 and the other two over 0, -1, 1;
# the order is significant because it fixes the order of the height
# filtrations built from this list.
direction_list = [
    [dx, dy, dz]
    for dx in (-1, 0, 1)
    for dy in (0, -1, 1)
    for dz in (0, -1, 1)
    if (dx, dy, dz) != (0, 0, 0)
]


def find_centers(grid_size, center_per_axis=3):
    """Return a regular lattice of interior grid coordinates.

    Each axis is divided into ``center_per_axis + 1`` equal steps (truncated
    to whole cells) and the lattice points at multiples 1..``center_per_axis``
    of that step are returned, so no point lies on the lower boundary.

    Args:
        grid_size: Sequence holding the extent of the grid along x, y and z.
            Only the first three entries are read.
        center_per_axis: Number of lattice points per axis. Defaults to 3,
            which yields 27 centers.

    Returns:
        A list of ``center_per_axis ** 3`` ``[x, y, z]`` lists, ordered with
        x varying slowest and z fastest.
    """
    divisor = 1 + center_per_axis
    x_jump = int(grid_size[0] / divisor)
    y_jump = int(grid_size[1] / divisor)
    z_jump = int(grid_size[2] / divisor)

    multiples = range(1, divisor)
    return [
        [i * x_jump, j * y_jump, k * z_jump]
        for i in multiples
        for j in multiples
        for k in multiples
    ]

def generate_tda_vector(x):
    """Turn a batch of 3-D binary voxel images into topological features.

    Every image goes through a bank of filtrations: the four parameter-free
    ones, one height filtration per entry of ``direction_list`` and one
    radial filtration per center returned by ``find_centers``.  Each filtered
    image is fed to cubical persistence in homology dimensions 0, 1 and 2,
    and each resulting diagram is summarised by persistence entropy plus one
    amplitude per entry of ``metric_list``.  All summaries are concatenated
    in that order.

    Args:
        x: Array whose first axis indexes samples and whose remaining three
            axes are the voxel grid; ``main`` passes shape ``(1, nx, ny, nz)``.

    Returns:
        The result of ``fit_transform`` on the assembled feature union.
        Callers in this file read row 0 as the feature vector of the single
        sample they pass in.
    """
    # Radial centers must be expressed in voxel-grid coordinates, so drop the
    # leading sample axis.  Passing the full 4-D shape made find_centers()
    # treat the sample axis (length 1) as the x-extent, which pinned every
    # center to x == 0 and shifted the other two extents by one axis.
    center_list = find_centers(x.shape[1:])

    filtration_list_h = [
        HeightFiltration(direction=np.array(direction), n_jobs=1)
        for direction in direction_list
    ]

    filtration_list_r = [
        RadialFiltration(center=np.array(center), n_jobs=1)
        for center in center_list
    ]

    filtration_list_DEDS = [
        DilationFiltration(),
        ErosionFiltration(),
        SignedDistanceFiltration(),
        DensityFiltration(),
    ]

    # One [filtration, persistence] pair per filtration.  The concatenation
    # order below fixes the column order of the output.
    diagram_steps = [
        [
            filtration,
            CubicalPersistence(n_jobs=1, homology_dimensions=[0, 1, 2]),
        ]
        for filtration in filtration_list_DEDS + filtration_list_h + filtration_list_r
    ]

    metric_list = [
        {"metric": "bottleneck", "metric_params": {}},
        {"metric": "wasserstein", "metric_params": {"p": 1}},
        {"metric": "wasserstein", "metric_params": {"p": 2}},
        {"metric": "landscape", "metric_params": {"p": 1, "n_layers": 1}},
        {"metric": "landscape", "metric_params": {"p": 1, "n_layers": 2}},
        {"metric": "landscape", "metric_params": {"p": 2, "n_layers": 1}},
        {"metric": "landscape", "metric_params": {"p": 2, "n_layers": 2}},
        {"metric": "betti", "metric_params": {"p": 1}},
        {"metric": "betti", "metric_params": {"p": 2}},
        # n_bins is kept at 20 because 100 is sometimes slow.
        {"metric": "heat", "metric_params": {"p": 1, "sigma": 0.15, "n_bins": 20}},
        {"metric": "heat", "metric_params": {"p": 2, "sigma": 0.15, "n_bins": 20}},
    ]

    # Per-diagram summaries: entropy first, then the amplitudes in the order
    # of metric_list.
    feature_union = make_union(
        *[PersistenceEntropy(nan_fill_value=-1, n_jobs=1)]
        +
        [Amplitude(**metric, n_jobs=1) for metric in metric_list],
        verbose=False,
        n_jobs=1
    )

    # Parallelism is applied across the filtration pipelines only; every
    # inner estimator runs with n_jobs=1.
    tda_union = make_union(
        *[make_pipeline(*diagram_step, feature_union, verbose=False) for diagram_step in diagram_steps],
        n_jobs=16, verbose=False
    )

    return tda_union.fit_transform(x)

def uniform_dowsampling_2048(pointcloud):
    """Randomly keep at most 2048 rows of ``pointcloud``.

    Rows are chosen without replacement by taking the head of a random
    permutation drawn from NumPy's global random state.  An input with fewer
    than 2048 rows is returned in full, in shuffled order.
    """
    shuffled_rows = np.random.permutation(pointcloud.shape[0])
    kept_rows = shuffled_rows[:2048]
    return pointcloud[kept_rows]

import warnings

# Ignore FutureWarning messages issued after this point; warnings raised by
# the imports at the top of the file have already been emitted by then.
warnings.simplefilter("ignore", FutureWarning)

# Edge length of one voxel, in the units of the point-cloud coordinates.
_VOXEL_SIZE = 0.05


def _read_file_names(list_path):
    """Return the point-cloud path of every sample named in ``list_path``."""
    file_names = []
    with open(list_path, 'r') as fp:
        for line in fp:
            sample = line.strip()
            if not sample:
                # A blank line would otherwise turn into a bogus path that
                # only fails once the loader reaches it.
                continue
            file_names.append('modelnet40_3cols/'+get_label(sample)+'/'+sample+'.txt')
    return file_names


def _voxelize(points, voxel_size=_VOXEL_SIZE):
    """Rasterise ``points`` into a uint8 occupancy grid (1 = occupied)."""
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)

    voxel_grid = o3d.geometry.VoxelGrid.create_from_point_cloud(pcd, voxel_size=voxel_size)
    # One row of grid indices per occupied voxel; reshape keeps the array
    # two-dimensional even if the voxel list is empty.
    voxel_indices = np.array(
        [voxel.grid_index for voxel in voxel_grid.get_voxels()], dtype=int
    ).reshape(-1, 3)

    # The grid is sized from the bounding box of the points.
    # NOTE(review): assumes Open3D's grid indices always fall inside it.
    min_bound = np.min(points, axis=0)
    max_bound = np.max(points, axis=0)
    grid_size = 1 + np.ceil((max_bound - min_bound) / voxel_size).astype(int)

    binary_image = np.zeros(grid_size, dtype=np.uint8)
    binary_image[voxel_indices[:, 0], voxel_indices[:, 1], voxel_indices[:, 2]] = 1
    return binary_image


def _compute_tda_vectors(file_names):
    """Return one TDA feature vector per point-cloud file, in input order."""
    tda_vectors = []
    for pcd_path in tqdm(file_names):
        # ndmin=2 keeps a single-row file two-dimensional, so that the
        # downsampling step indexes points rather than coordinates.
        points = np.loadtxt(pcd_path.strip(), delimiter=",", ndmin=2)
        points = uniform_dowsampling_2048(points)
        binary_image = _voxelize(points)
        # generate_tda_vector expects a leading sample axis.
        tda_vector = generate_tda_vector(binary_image[np.newaxis, ...])
        tda_vectors.append(tda_vector[0])
    return tda_vectors


def _write_vectors(out_path, vectors, file_names):
    """Write one line per sample: space-separated features, then the label."""
    with open(out_path, "w") as f:
        for vector, file_name in zip(vectors, file_names):
            f.write(f"{' '.join(map(str, vector))} {get_label(file_name)}\n")


def main():
    """Extract TDA features for the ModelNet40 training and test splits.

    Reads the sample lists 'modelnet40_train.txt' and 'modelnet40_test.txt',
    loads each listed point cloud from 'modelnet40_3cols/<label>/', randomly
    downsamples it, voxelises it and computes its feature vector.  The
    training vectors are written to 'train-ModelNet40.txt' before the test
    split is processed; the test vectors go to 'test-modelnet40.txt'.  Each
    output line holds the features followed by the class label.
    """
    training_file_names = _read_file_names('modelnet40_train.txt')
    testing_file_names = _read_file_names('modelnet40_test.txt')

    print(len(training_file_names),'files in training')
    print(len(testing_file_names),'files in testing')

    training_tda_vecs = _compute_tda_vectors(training_file_names)
    _write_vectors("train-ModelNet40.txt", training_tda_vecs, training_file_names)

    test_tda_vecs = _compute_tda_vectors(testing_file_names)
    _write_vectors("test-modelnet40.txt", test_tda_vecs, testing_file_names)

# Run the feature extraction only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
